Add cuBLAS mm_out shim to eliminate libtorch runtime dependency #19360

digantdesai wants to merge 1 commit into
Conversation
Implements `aoti_torch_cuda_mm_out` as a thin cuBLAS wrapper in the ExecuTorch AOTI CUDA shims. When Inductor picks cuBLAS over Triton templates for `aten::mm` (`F.linear`), the compiled `.so` requires this symbol at runtime. Without this shim, the symbol resolves from `libtorch_cuda.so`, pulling the full libtorch runtime back in.

In practice, Inductor's autotuning on A100 picks Triton templates for the Qwen3.5 MoE dense projections (bf16 `[M,2048] x [2048,N]`), so the shim is not exercised for this model. It serves as a safety net for models or shapes where cuBLAS wins the autotune, ensuring fully libtorch-free AOTI CUDA deployment in all cases.

Co-authored-by: Claude <noreply@anthropic.com>
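For orientation, here is a minimal sketch of the kind of call such a shim makes. This is an illustration, not the PR's code: the helper name is invented, it handles only fp32, and error handling plus the bf16/fp16 dtype dispatch (`CUDA_R_16BF`/`CUDA_R_16F`) are elided. Row-major `aten::mm` maps onto column-major cuBLAS by computing `out^T = mat2^T * self^T`, i.e. by swapping the operand order:

```cpp
#include <cublas_v2.h>

// Illustrative sketch only (function name invented for this example):
// computes out[M,N] = self[M,K] * mat2[K,N] for fp32, row-major
// contiguous buffers. cuBLAS is column-major, so a row-major [M,N]
// buffer is already a column-major [N,M] matrix; computing
// out^T = mat2^T * self^T with swapped operands needs no transpose
// flags and no staging copies.
cublasStatus_t mm_out_f32_sketch(
    cublasHandle_t handle,
    const float* self,
    const float* mat2,
    float* out,
    int M,
    int N,
    int K) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return cublasGemmEx(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      /*m=*/N,
      /*n=*/M,
      /*k=*/K,
      &alpha,
      mat2, CUDA_R_32F, /*lda=*/N,
      self, CUDA_R_32F, /*ldb=*/K,
      &beta,
      out, CUDA_R_32F, /*ldc=*/N,
      CUBLAS_COMPUTE_32F,
      CUBLAS_GEMM_DEFAULT);
}
```

Because the operand swap absorbs the layout mismatch, the wrapper stays a thin pass-through to `cublasGemmEx`.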
Force-pushed from 7316ecf to b6b4ad7
Can you help me update the title and summary a little bit? One thing: our CUDA backend never depends on libtorch, but the current wording sounds like we do.
Done. If we use F.linear in the kernel, which can fall back to cuBLAS on CUDA, we need to add this even if Qwen happens not to use it.
Thanks. Mind updating the title as well? Also:
Pull request overview
This PR adds an ExecuTorch AOTI CUDA shim implementation of aoti_torch_cuda_mm_out backed by cuBLAS to avoid requiring libtorch_cuda.so at runtime when Inductor emits calls to that symbol (e.g., for aten::mm / F.linear).
Changes:
- Introduces a cuBLAS-based `aoti_torch_cuda_mm_out` shim (`mm.h`/`mm.cu`) and links cuBLAS into `aoti_cuda_shims`.
- Adds a new GTest suite covering correctness (bf16/fp16/fp32) and contract validation for `mm_out`.
- Registers the new test in both CMake and Buck/Bazel test target definitions.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| `backends/cuda/runtime/shims/mm.h` | Declares the exported `aoti_torch_cuda_mm_out` shim API. |
| `backends/cuda/runtime/shims/mm.cu` | Implements `aoti_torch_cuda_mm_out` via `cublasGemmEx` with per-device handle management. |
| `backends/cuda/CMakeLists.txt` | Adds `mm.cu` to shim sources and links `CUDA::cublas` into `aoti_cuda_shims`. |
| `backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_mm_out.cpp` | Adds typed correctness tests and contract-validation tests for the new shim. |
| `backends/cuda/runtime/shims/tests/CMakeLists.txt` | Builds and registers the new mm_out test binary. |
| `backends/cuda/runtime/shims/tests/targets.bzl` | Registers the new mm_out test in shared Buck/Bazel target definitions. |
Comments suppressed due to low confidence (1)
backends/cuda/runtime/shims/tests/targets.bzl:46
- This adds a Buck test target for `aoti_torch_cuda_mm_out`, but `//executorch/backends/cuda/runtime:runtime_shims` (a dependency of `cuda_shim_cpp_unittest`) does not currently list `shims/mm.cu` in `srcs` or `shims/mm.h` in `headers`. As a result, the test will fail to compile/link under Buck. Please add the new shim source/header to the `runtime_shims` library (and any needed CUDA/cuBLAS deps there), or adjust the test deps to a target that exports these files.
cuda_shim_cpp_unittest("aoti_torch_cuda_rand")
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
cuda_shim_cpp_unittest("aoti_torch_cuda_mm_out")
```cpp
// Per-device handle; mutex in get() ensures thread-safe initialization.
// cublasSetStream + cublasGemmEx are serialized under the same mutex to
// prevent races when multiple threads share a device.
auto& handles = cublas_handles();
std::lock_guard<std::mutex> lock(handles.mutex);
cublasHandle_t handle = handles.get(device);
cublasSetStream(handle, stream_result.get());
```
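The handle cache itself is not shown in this hunk. A plausible shape for it, under the locking discipline the comment describes, looks like the following sketch (the struct and accessor names are assumptions for illustration, not the PR's actual code):

```cpp
#include <cublas_v2.h>
#include <mutex>
#include <unordered_map>

// Hypothetical per-device cuBLAS handle cache (names invented).
struct CublasHandles {
  std::mutex mutex; // guards the map and serializes cuBLAS calls
  std::unordered_map<int, cublasHandle_t> by_device;

  // Caller must hold `mutex`, as the shim does via std::lock_guard.
  cublasHandle_t get(int device) {
    auto it = by_device.find(device);
    if (it == by_device.end()) {
      cublasHandle_t h = nullptr;
      cublasCreate(&h); // error handling elided for brevity
      it = by_device.emplace(device, h).first;
    }
    return it->second;
  }
};

CublasHandles& cublas_handles() {
  static CublasHandles handles; // C++11 static: thread-safe initialization
  return handles;
}
```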
```cpp
* @param self Input matrix [M, K]. Must be bf16 or fp16, 2D, contiguous.
* @param mat2 Input matrix [K, N]. Must be bf16 or fp16, 2D, contiguous.
```
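These contracts imply up-front validation before the GEMM launches. A sketch of such a check, assuming a tensor API with `dim()`, `size()`, and `stride()` accessors (the helper name and exact API are assumptions):

```cpp
// Hypothetical helper: a row-major [R, C] tensor is contiguous iff
// stride(1) == 1 and stride(0) == size(1). The bf16/fp16 dtype check
// from the doc comment would be layered on top of this.
bool is_valid_mm_input(const Tensor& t) {
  return t.dim() == 2 && t.stride(1) == 1 && t.stride(0) == t.size(1);
}
```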
```cmake
# mm_out test — cuBLAS is already linked into aoti_cuda_shims
add_executable(test_aoti_torch_cuda_mm_out test_aoti_torch_cuda_mm_out.cpp)

target_include_directories(
  test_aoti_torch_cuda_mm_out PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
  ${CUDAToolkit_INCLUDE_DIRS}
)

target_compile_definitions(test_aoti_torch_cuda_mm_out PRIVATE CUDA_AVAILABLE=1)

target_link_libraries(
  test_aoti_torch_cuda_mm_out
  PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims executorch_core
          CUDA::cudart
)
```
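Assuming the binary is registered with CTest in the usual way (e.g. via `add_test` or `gtest_discover_tests`, as the "registers" wording above suggests), it can be run selectively from the build directory with `ctest -R test_aoti_torch_cuda_mm_out`.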
```cpp
TEST_F(AOTITorchMmOutTest, NonContiguousRejected) {
  // Create a [8, 8] tensor and slice rows to get non-contiguous [4, 8]
  int64_t big_sizes[] = {8, 8};
  int64_t big_strides[] = {8, 1};
  Tensor* big = nullptr;
  aoti_torch_empty_strided(
      2,
      big_sizes,
      big_strides,
      static_cast<int32_t>(slim_c10::ScalarType::Float),
      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
      0,
      &big);
```